import random
import time
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from experiments.utils import load_dataset_safely, seed_everything
from experiments.components import COMPONENT_MAP, COMPONENT_META  # import your maps

# --- GPU replacements (optional if RAPIDS/cuML is installed) ---
# Attempt to import the cuML SVC and LogisticRegression classes.
# GPU_AVAILABLE records only whether these imports succeeded; the names
# cuSVC / cuLogReg are bound only in that case, so every use of them must be
# guarded by GPU_AVAILABLE.
try:
    from cuml.svm import SVC as cuSVC
    from cuml.linear_model import LogisticRegression as cuLogReg
    GPU_AVAILABLE = True
except ImportError:
    # NOTE(review): only ImportError is handled; any other exception raised
    # while importing cuml would propagate and abort module import — confirm
    # that is the desired behaviour on machines without a usable GPU.
    GPU_AVAILABLE = False


def _unique_step_name(base, taken):
    """Return *base*, or *base* plus a numeric suffix if it is already in *taken*."""
    if base not in taken:
        return base
    suffix = 2
    while f"{base}_{suffix}" in taken:
        suffix += 1
    return f"{base}_{suffix}"


def sample_pipeline(rng):
    """
    Randomly build a pipeline from COMPONENT_MAP using COMPONENT_META grammar.

    Components are drawn one at a time with ``rng.choice`` from every entry of
    COMPONENT_META whose role is "transformer", "estimator" or "terminator".
    Sampling stops as soon as an estimator-role component has been appended,
    or when the name "END_PIPELINE" is drawn.

    Args:
        rng: Random source exposing ``choice(sequence)``; the caller in this
            file passes a ``numpy.random.RandomState``.

    Returns:
        A ``Pipeline`` holding the sampled ``(name, component)`` steps.

    NOTE(review): when "END_PIPELINE" is drawn, the returned pipeline contains
    no estimator-role step and may have no steps at all — confirm callers are
    prepared for that (the search loop in this file treats a failing fit as a
    score of 0.0).
    """
    # Local import: clone is only needed here.
    from sklearn.base import clone

    # The loop exits immediately after an estimator is appended, so the set of
    # eligible components never changes between draws; build it once.
    allowed_roles = ("transformer", "estimator", "terminator")
    candidates = [
        name for name, meta in COMPONENT_META.items()
        if meta["role"] in allowed_roles
    ]

    steps = []
    taken_names = set()

    while True:
        # numpy's choice returns numpy.str_; normalise to a plain str.
        choice = str(rng.choice(candidates))
        if choice == "END_PIPELINE":
            break

        # Clone so that pipelines never share (and re-fit) the object stored in
        # COMPONENT_MAP, including when the same component is drawn twice for
        # one pipeline. safe=False falls back to a deep copy for non-estimators.
        comp = clone(COMPONENT_MAP[choice], safe=False)

        # swap to GPU versions if available
        if GPU_AVAILABLE:
            if choice.startswith("SVC("):
                comp = cuSVC()  # simple GPU fallback
            elif choice.startswith("LogisticRegression("):
                comp = cuLogReg(max_iter=1000)

        # Pipeline requires unique step names; a component drawn more than once
        # gets a numeric suffix instead of making the whole pipeline invalid.
        step_name = _unique_step_name(choice, taken_names)
        taken_names.add(step_name)
        steps.append((step_name, comp))

        if COMPONENT_META[choice]["role"] == "estimator":
            break  # once we have estimator, stop

    return Pipeline(steps)


def random_search_multi_component(dataset_name: str, n_iter: int = 50, random_state: int = 42) -> dict:
    """
    Run a random search over sampled pipelines and report the best one.

    Each trial samples a pipeline with ``sample_pipeline``, fits it on the
    training split and scores it by accuracy on the validation split. A trial
    that raises during fit/predict is reported and counted as a score of 0.0.

    Args:
        dataset_name: Name passed to ``load_dataset_safely``.
        n_iter: Number of pipelines to sample and evaluate.
        random_state: Seed passed to ``seed_everything`` and used for the
            sampler's ``numpy.random.RandomState``.

    Returns:
        A dict with keys ``"val_score"`` (best validation accuracy, 0.0 if no
        trial scored above zero), ``"best_pipeline"`` (string form of the best
        pipeline, ``"None"`` if there was none) and ``"time_sec"`` (elapsed
        seconds spent in the trial loop).

    Raises:
        RuntimeError: If ``load_dataset_safely`` returns no data; the message
            is the one it returned.
    """
    seed_everything(random_state)
    data, msg = load_dataset_safely(dataset_name)
    if data is None:
        raise RuntimeError(msg)

    X_train, y_train = data["X_train"], data["y_train"]
    X_val, y_val = data["X_val"], data["y_val"]

    rng = np.random.RandomState(random_state)
    best_score = 0.0
    # Only the textual form of the winner is returned, so keep just that rather
    # than holding a fitted pipeline alive for the whole search.
    best_description = str(None)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()

    for i in range(n_iter):
        pipe = sample_pipeline(rng)
        try:
            pipe.fit(X_train, y_train)
            preds = pipe.predict(X_val)
            score = accuracy_score(y_val, preds)
        except Exception as exc:
            # Best-effort search: an invalid pipeline must not abort the run.
            score = 0.0
            print(f"[Trial {i}] Failed: {exc}")

        description = str(pipe)
        if score > best_score:
            best_score, best_description = score, description

        print(f"[Trial {i}] Score={score:.4f} | Pipeline={description}")

    duration = time.perf_counter() - started
    return {
        "val_score": best_score,
        "best_pipeline": best_description,
        "time_sec": duration,
    }


if __name__ == "__main__":
    # Command-line entry point: run one random search and print its result.
    import argparse

    parser = argparse.ArgumentParser(
        description="Random search over multi-component pipelines."
    )
    parser.add_argument("--dataset", default="iris")
    parser.add_argument("--n_iter", type=int, default=30)
    # Expose the seed so runs can be varied or reproduced from the CLI; the
    # default matches random_search_multi_component's own default.
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    res = random_search_multi_component(
        args.dataset, args.n_iter, random_state=args.seed
    )
    print(res)
